import os
import json
import argparse
import numpy as np
import pydicom
import SimpleITK as sitk
from PIL import Image
from tqdm import tqdm
from skimage import measure
import torch
import cv2
from sam2.build_sam import build_sam2_video_predictor_npz


def resize_grayscale_to_rgb_and_resize(array, image_size):
    """Turn a stack of grayscale slices into square, channel-first RGB frames.

    Every slice is cast to ``uint8`` (a plain cast, no rescaling), expanded
    to three channels through PIL's ``convert("RGB")``, resized with PIL's
    default resampling filter and stored with the channel axis first.

    Args:
        array: 3-D array; the first axis indexes the slices.
        image_size: Side length in pixels of the square output frames.

    Returns:
        A float64 array of shape ``(d, 3, image_size, image_size)`` where
        ``d`` is the number of input slices.

    Raises:
        ValueError: If ``array`` does not have exactly three dimensions.
    """
    depth, _, _ = array.shape
    target = (image_size, image_size)
    frames = np.zeros((depth, 3) + target)
    for idx in range(depth):
        as_rgb = Image.fromarray(array[idx].astype(np.uint8)).convert("RGB")
        # PIL yields (H, W, C); the output layout is (C, H, W).
        frames[idx] = np.array(as_rgb.resize(target)).transpose(2, 0, 1)
    return frames


def getLargestCC(segmentation):
    """Return a boolean mask selecting the largest connected component.

    Args:
        segmentation: Array whose non-zero entries are foreground.

    Returns:
        A boolean array with the same shape as ``segmentation`` that is
        ``True`` only on the connected component with the most elements.
        If ``segmentation`` has no foreground at all, the result is all
        ``False`` (previously this case raised ``ValueError`` from
        ``np.argmax`` on an empty array).
    """
    # No foreground means no component to pick; answer directly instead of
    # letting argmax fail on the empty size table below.
    if not np.any(segmentation):
        return np.zeros(np.shape(segmentation), dtype=bool)

    labels = measure.label(segmentation)
    # Label 0 is the background, so drop its count and shift the winning
    # index back by one to recover the actual label value.
    component_sizes = np.bincount(labels.flat)[1:]
    largestCC = labels == np.argmax(component_sizes) + 1
    return largestCC


def read_dicom_volume(dicom_dir):
    """Load the numerically named ``.dcm`` files of a directory as a volume.

    Files are ordered by the integer value of their file-name stem
    (``12.dcm`` -> slice number 12). Files whose stem is not an integer and
    files that cannot be read are skipped with a warning, so one stray file
    does not discard the whole series.

    Args:
        dicom_dir: Directory containing the ``<slice number>.dcm`` files.

    Returns:
        A tuple ``(image, slice_numbers)``: ``image`` is a SimpleITK image
        built from the stacked slices, and ``slice_numbers[i]`` is the
        file-name slice number of the i-th slice in that stack.

    Raises:
        ValueError: If no slice could be loaded, or if the loaded slices
            cannot be stacked (raised by ``np.stack``).
    """
    numbered_files = []
    for name in os.listdir(dicom_dir):
        if not name.endswith(".dcm"):
            continue
        stem = os.path.splitext(name)[0]
        try:
            numbered_files.append((int(stem), name))
        except ValueError:
            print(f"[⚠️ WARNING] Skipping {name}: file name is not a slice number")
    # Sort on the slice number only, keeping directory order for ties.
    numbered_files.sort(key=lambda item: item[0])

    slices = []
    slice_numbers = []

    for slice_num, file in numbered_files:
        path = os.path.join(dicom_dir, file)
        try:
            ds = pydicom.dcmread(path)
            # NOTE(review): stored pixel values are cast to int16 as-is; no
            # rescale slope/intercept is applied here.
            img = ds.pixel_array.astype(np.int16)
        except Exception as e:
            # Best effort: one unreadable file must not abort the series.
            print(f"[⚠️ WARNING] Couldn't read {file}: {e}")
            continue
        slices.append(img)
        slice_numbers.append(slice_num)

    if not slices:
        raise ValueError("No valid DICOM slices found.")

    volume = np.stack(slices, axis=0)
    return sitk.GetImageFromArray(volume), slice_numbers


# Global numeric setup: relaxed float32 matmul precision and fixed RNG seeds.
torch.set_float32_matmul_precision('high')
torch.manual_seed(2024)
torch.cuda.manual_seed(2024)
np.random.seed(2024)

# Side length of the square frames handed to the predictor.
IMAGE_SIZE = 512

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='checkpoints/MedSAM2_CTLesion.pt')
parser.add_argument('--cfg', type=str, default='configs/sam2.1/sam2.1_hiera_t512.yaml')
# Both paths are mandatory: the former '' defaults only failed later, inside
# os.makedirs('') / open(''), with a much less helpful error.
parser.add_argument('--input_json', type=str, required=True,
                    help="JSON list of samples with the keys 'dicom_path', "
                         "'key_slice_mask_path', 'key_slice_index' and 'series_id'")
parser.add_argument('--pred_save_dir', type=str, required=True,
                    help='directory that receives one sub-directory of PNG masks per sample')
args = parser.parse_args()

os.makedirs(args.pred_save_dir, exist_ok=True)
predictor = build_sam2_video_predictor_npz(args.cfg, args.checkpoint)

with open(args.input_json, 'r') as f:
    data = json.load(f)

# Per-channel normalisation constants; loop-invariant, so built once.
img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)[:, None, None].cuda()
img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)[:, None, None].cuda()

skipped = 0

for sample in tqdm(data):
    image_path = sample['dicom_path']
    mask_path = sample['key_slice_mask_path']
    key_slice_index = sample['key_slice_index']
    save_name = sample['series_id'] + f"_k{key_slice_index}"

    try:
        img_itk, slice_numbers = read_dicom_volume(image_path)
    except Exception as e:
        print(f"[❌ ERROR] Could not read DICOM series at {image_path}: {e}")
        skipped += 1
        continue

    # Map file-name slice numbers to positions along the volume's first axis.
    slice_index_map = {sn: idx for idx, sn in enumerate(slice_numbers)}

    if key_slice_index not in slice_index_map:
        print(f"[❌ SKIP] Key slice index {key_slice_index} not in available slices: {slice_numbers}")
        skipped += 1
        continue

    key_slice_offset = slice_index_map[key_slice_index]
    img_np = sitk.GetArrayFromImage(img_itk)

    try:
        mask_np = np.array(Image.open(mask_path).convert('L'))
        mask_np = (mask_np > 127).astype(np.uint8)
    except Exception as e:
        print(f"[❌ ERROR] Could not read mask {mask_path}: {e}")
        skipped += 1
        continue

    key_slice_shape = img_np[key_slice_offset].shape
    if mask_np.shape != key_slice_shape:
        print(f"[❌ ERROR] Mask shape {mask_np.shape} does not match key slice shape {key_slice_shape}")
        skipped += 1
        continue

    # Derive the box prompt [x_min, y_min, x_max, y_max] from the mask. An
    # empty mask has no extent; skip it instead of crashing the whole run on
    # min()/max() of an empty array.
    fg_rows, fg_cols = np.nonzero(mask_np)
    if fg_rows.size == 0:
        print(f"[❌ SKIP] Mask {mask_path} has no foreground pixels; cannot derive a box prompt")
        skipped += 1
        continue
    box = [fg_cols.min(), fg_rows.min(), fg_cols.max(), fg_rows.max()]

    # The resize helper casts each slice straight to uint8, so anything
    # outside 0..255 wraps around. Surface that instead of hiding it.
    # NOTE(review): an intensity window is probably needed upstream - confirm.
    value_min, value_max = img_np.min(), img_np.max()
    if value_min < 0 or value_max > 255:
        print(f"[⚠️ WARNING] {save_name}: pixel values span [{value_min}, {value_max}] "
              f"and wrap around when cast to uint8 for the model input")

    img_resized = resize_grayscale_to_rgb_and_resize(img_np, IMAGE_SIZE)
    img_resized = torch.from_numpy(img_resized / 255.0).cuda()
    img_resized = (img_resized - img_mean) / img_std

    segs_3D = np.zeros(img_np.shape, dtype=np.uint8)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        inference_state = predictor.init_state(img_resized, img_np.shape[1], img_np.shape[2])
        # Two passes from the key slice: forward first, then backward. The
        # state is reset after each pass, so the box prompt is re-added.
        for reverse in (False, True):
            predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=key_slice_offset,
                obj_id=1,
                box=box,
            )
            for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
                inference_state, reverse=reverse
            ):
                segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
            predictor.reset_state(inference_state)

    # Keep only the largest connected component of the propagated mask.
    if np.max(segs_3D) > 0:
        segs_3D = np.uint8(getLargestCC(segs_3D))

    mask_dir = os.path.join(args.pred_save_dir, save_name)
    os.makedirs(mask_dir, exist_ok=True)
    for i, slice_id in enumerate(slice_numbers):
        out_path = os.path.join(mask_dir, f"{slice_id:03d}.png")
        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite(out_path, segs_3D[i] * 255):
            print(f"[⚠️ WARNING] Failed to write {out_path}")

if skipped:
    print(f"⚠️ Finished with {skipped} of {len(data)} samples skipped; see the messages above.")
else:
    print("✅ All 3D lesion masks saved successfully as PNG slices.")
